{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "# default_exp problem_types.vector_fit\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:bert_config not exists. will load model from huggingface checkpoint.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Adding new problem weibo_fake_ner, problem type: seq_tag\n",
      "Adding new problem weibo_cws, problem type: seq_tag\n",
      "Adding new problem weibo_fake_multi_cls, problem type: multi_cls\n",
      "Adding new problem weibo_fake_cls, problem type: cls\n",
      "Adding new problem weibo_masklm, problem type: masklm\n",
      "Adding new problem weibo_pretrain, problem type: pretrain\n",
      "Adding new problem weibo_fake_regression, problem type: regression\n",
      "Adding new problem weibo_fake_vector_fit, problem type: vector_fit\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/m3tl/m3tl/read_write_tfrecord.py:83: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  elif np.issubdtype(type(feature), np.float):\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:sampling weights: \n",
      "INFO:tensorflow:weibo_fake_cls_weibo_fake_ner_weibo_fake_regression_weibo_fake_vector_fit: 0.6666666666666666\n",
      "INFO:tensorflow:weibo_fake_multi_cls: 0.16666666666666666\n",
      "INFO:tensorflow:weibo_masklm: 0.16666666666666666\n"
     ]
    }
   ],
   "source": [
    "# test setup\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from m3tl.test_base import TestBase\n",
    "from m3tl.input_fn import train_eval_input_fn\n",
    "from m3tl.test_base import test_top_layer\n",
    "test_base = TestBase()\n",
    "params = test_base.params\n",
    "\n",
    "hidden_dim = params.bert_config.hidden_size\n",
    "\n",
    "train_dataset = train_eval_input_fn(params=params)\n",
    "one_batch = next(train_dataset.as_numpy_iterator())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# VectorFit(vector_fit)\n",
    "\n",
    "This module includes the necessary parts to register the vector_fit problem type."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and utils\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "from typing import Dict, Tuple\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from m3tl.base_params import BaseParams\n",
    "from m3tl.problem_types.utils import (empty_tensor_handling_loss,\n",
    "                                      nan_loss_handling)\n",
    "from m3tl.special_tokens import PREDICT\n",
    "from m3tl.utils import get_phase\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Top Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def cosine_wrapper(labels, logits, from_logits=True):\n",
    "    return tf.keras.losses.cosine_similarity(labels, logits)\n",
    "\n",
    "\n",
    "class VectorFit(tf.keras.Model):\n",
    "    def __init__(self, params: BaseParams, problem_name: str) -> None:\n",
    "        super(VectorFit, self).__init__(name=problem_name)\n",
    "        self.params = params\n",
    "        self.problem_name = problem_name\n",
    "        self.num_classes = self.params.get_problem_info(problem=problem_name, info_name='num_classes')\n",
    "        self.dense = tf.keras.layers.Dense(self.num_classes)\n",
    "\n",
    "    def call(self, inputs: Tuple[Dict]):\n",
    "        mode = get_phase()\n",
    "        feature, hidden_feature = inputs\n",
    "        pooled_hidden = hidden_feature['pooled']\n",
    "\n",
    "        logits = self.dense(pooled_hidden)\n",
    "        if mode != PREDICT:\n",
    "            # this is actually a vector\n",
    "            label = feature['{}_label_ids'.format(self.problem_name)]\n",
    "\n",
    "            loss = empty_tensor_handling_loss(label, logits, cosine_wrapper)\n",
    "            loss = nan_loss_handling(loss)\n",
    "            self.add_loss(loss)\n",
    "\n",
    "            self.add_metric(tf.math.negative(\n",
    "                loss), name='{}_cos_sim'.format(self.problem_name), aggregation='mean')\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing VectorFit\n"
     ]
    }
   ],
   "source": [
    "test_top_layer(VectorFit, problem='weibo_fake_vector_fit', params=params, sample_features=one_batch, hidden_dim=hidden_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get or make label encoder function\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def vector_fit_get_or_make_label_encoder_fn(params: BaseParams, problem, mode, label_list, *args, **kwargs):\n",
    "    if label_list:\n",
    "        # set params num_classes for this problem\n",
    "        label_array = np.array(label_list)\n",
    "        params.set_problem_info(problem=problem, info_name='num_classes', info=label_array.shape[-1])\n",
    "    return None\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Label handing function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def vector_fit_label_handling_fn(target, label_encoder=None, tokenizer=None, decoding_length=None, *args, **kwargs):\n",
    "    # return label_id and label mask\n",
    "    label_id = np.array(target, dtype='float32')\n",
    "    return label_id, None\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.10 64-bit ('base': conda)",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
